On-device perf + memory optimizations: custom SDPA, on-the-fly RoPE, KV cache fix, XNNPACK workspace sharing #19214
leixin wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19214

❌ 2 New Failures, 1 Unrelated Failure as of commit 3b254e4 with merge base d767516. The unrelated failure is broken trunk (the job also fails on the merge base); 👉 rebase onto the `viable/strict` branch to avoid it.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@leixin has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102710062.
Force-pushed from 1e95fc2 to 3b254e4
Summary:
Six changes for the Gemma 4 text decoder + runner, enabled by default. Custom SDPA can be opted out via `--no-use_custom_sdpa` for eager mode or non-XNNPACK backends. Workspace sharing can be opted out via `--noenable_workspace_sharing` for debugging.

1. Custom SDPA: attention now runs through `torch.ops.llama.custom_sdpa` (tiled flash attention from the Llama runner). This skips the 8x KV expansion that GQA/MQA otherwise requires and never materializes the full `[seq, seq]` attention matrix; the matmul fallback's `[bs, heads, seq, seq]` tensor exceeds the S25's 8 MB L2 cache at `seq=2048` and causes a severe regression. Also adds an inline INT8 dequant path for `Gemma4QuantizedKVCache(return_float_values=False)` that stays inside the XNNPACK partition. (Illustrative sketch after the list.)
2. On-the-fly RoPE: the attention module stores only the `inv_freq` vector (~128-256 floats) and computes cos/sin per forward, instead of registering precomputed `[max_seq_len, head_dim]` cos/sin buffers. Reduces PTE size by 3-7%. (Sketch after the list.)
3. KV cache allocation is skipped for `is_kv_shared_layer=True`. In YOCO, 20 of 35 layers consume the donor's KV via `shared_kv` and never write to their own cache, so the allocation was dead. Saves ~40 MB at `seq=1024` and ~80 MB at `seq=2048`. (Sketch after the list.)
4. XNNPACK workspace sharing in the runner. `Gemma4Runner::load()` now calls `set_option(workspace_sharing_mode_option_key=PerModel, weight_cache_option_key=true)` on the XNNPACK backend before module load. Default-on, with an `enable_workspace_sharing` constructor flag for opt-out. Without this, real Android/iOS app builds (which don't pass the bench's compile-time `--config xnnpack_workspace_sharing=1`) end up in `Disabled` mode and silently OOM-crash on E4B (the >2 GB peak-memory regression reported by app teams). The compile-time flag in xplat/.../gemma4/targets.bzl (`-DENABLE_XNNPACK_SHARED_WORKSPACE`) is also removed since it was dead: Buck preprocessor flags don't reach `XNNWorkspaceManager.cpp`, which lives in the `xnnpack_backend` compile unit.
5. Correctness fix for KV cache quant + custom SDPA + YOCO. When `Gemma4QuantizedKVCache(return_float_values=False)` is in use, the donor layer now dequants K/V before storing them in `kv_to_share`, so cross-decoder layers (which lack access to the donor's scales) don't pass raw int8 to `custom_sdpa`. This was a dormant bug: it only triggers with `--quantize_kv_cache --use_custom_sdpa`, and previously crashed export with `AssertionError: Expected key to be float32`. (Sketch after the list.)
6. iOS VmRSS sscanf fix (consolidates D103030061). `Gemma4Stats::read_rss_kb()` uses `SCNd64` from `<cinttypes>` instead of `%ld`, so the format matches `int64_t` on both LP64 (Linux/Android) and LLP64-ish (iOS arm64) platforms. Unblocks iOS sample app builds with `-Werror,-Wformat`.

Mask construction is factored into `_build_attn_mask` / `_slice_mask` helpers shared between the custom-SDPA and matmul branches. (Sketch after the list.)

Differential Revision: D102710062
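For item 1, a minimal sketch of the two attention paths. The fallback below is the standard GQA matmul formulation; the `custom_sdpa` op name comes from the summary, but the tensor layout and argument list shown are assumptions for illustration, not the exact Gemma 4 code.

```python
import torch

def fallback_gqa_attention(q, k, v, mask, n_rep):
    # Matmul fallback: q is [bs, n_heads, seq, head_dim], k/v are [bs, n_kv_heads, seq, head_dim].
    # GQA/MQA forces an n_rep-fold (8x here) KV expansion, and the score tensor
    # below materializes the full [bs, n_heads, seq, seq] matrix.
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5) + mask
    return torch.softmax(scores, dim=-1) @ v

def custom_sdpa_attention(q, k, v, start_pos, mask):
    # Custom SDPA path: the op consumes the KV heads directly (no expansion)
    # and tiles the computation, so no [seq, seq] tensor is materialized.
    # Argument order and expected layout here are assumptions for illustration only.
    return torch.ops.llama.custom_sdpa(q, k, v, start_pos, mask, 0.0, False, None)
```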
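For item 2, a minimal sketch of on-the-fly RoPE, assuming a standard rotary formulation; the class and attribute names are illustrative, not the Gemma 4 module.

```python
import torch
from torch import nn

class OnTheFlyRope(nn.Module):
    """Stores only inv_freq (head_dim/2 floats); cos/sin are computed per forward."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, positions: torch.Tensor):
        # positions: [seq] absolute token positions for this forward call.
        freqs = positions.float()[:, None] * self.inv_freq[None, :]  # [seq, head_dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)                      # [seq, head_dim]
        return emb.cos(), emb.sin()
```

The buffers this replaces are the precomputed `[max_seq_len, head_dim]` cos/sin tables, which otherwise get serialized into the PTE.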
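For item 3, a sketch of the allocation skip. The attribute names, cache shape, and dtype are assumptions; the point is that YOCO consumer layers never write their own cache, so no buffer is registered for them.

```python
import torch
from torch import nn

class LayerKVCache(nn.Module):
    def __init__(self, max_seq_len, n_kv_heads, head_dim, is_kv_shared_layer: bool):
        super().__init__()
        self.is_kv_shared_layer = is_kv_shared_layer
        if not is_kv_shared_layer:
            # Only donor / self-caching layers get real storage.
            shape = (1, n_kv_heads, max_seq_len, head_dim)
            self.register_buffer("k_cache", torch.zeros(shape), persistent=False)
            self.register_buffer("v_cache", torch.zeros(shape), persistent=False)
        # Shared layers read the donor's KV via shared_kv at attention time
        # and never allocate k_cache / v_cache of their own.
```

With illustrative dims (1 KV head, head_dim 256, fp32), 20 consumer layers work out to 20 x 2 x 1024 x 256 x 4 B ≈ 42 MB at `seq=1024`, consistent with the ~40 MB figure above.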
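For item 5, a sketch of the donor-side fix: dequantize before publishing to `kv_to_share`, since consumer layers never see the donor's scales. Variable names and the symmetric per-channel quant layout are assumptions.

```python
import torch

def publish_shared_kv(k_int8, v_int8, k_scale, v_scale, kv_to_share: dict):
    """Donor layer (YOCO): dequantize INT8 KV to float before sharing."""
    # Before the fix, raw int8 tensors were stored here and consumer layers
    # fed them to custom_sdpa, tripping the float32 assertion at export time.
    # Zero points are omitted here on the assumption of symmetric quantization.
    kv_to_share["k"] = k_int8.to(torch.float32) * k_scale
    kv_to_share["v"] = v_int8.to(torch.float32) * v_scale
```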
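The mask refactor in plain form; the helper names come from the summary, while the signatures and shapes are assumptions for illustration.

```python
import torch

def _build_attn_mask(max_seq_len: int) -> torch.Tensor:
    # Full causal mask built once: 0 where attention is allowed, -inf above the diagonal.
    mask = torch.full((max_seq_len, max_seq_len), float("-inf"))
    return torch.triu(mask, diagonal=1)

def _slice_mask(mask: torch.Tensor, start_pos: int, seq_len: int) -> torch.Tensor:
    # Both the custom-SDPA and matmul branches take the same
    # [seq_len, start_pos + seq_len] window for the current decode step.
    return mask[start_pos : start_pos + seq_len, : start_pos + seq_len]
```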